{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8a_train_sqil_sac.ipynb)\n",
    "# Train an Agent using Soft Q Imitation Learning with SAC\n",
    "\n",
    "In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a Pendulum agent using SQIL + SAC."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need some expert trajectories in our environment (`Pendulum-v1`).\n",
    "Note that you can use other environments, but the action space must be continuous because SAC only supports continuous actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "from imitation.data import huggingface_utils\n",
    "\n",
    "# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
    "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-Pendulum-v1\")\n",
    "\n",
    "# Convert the dataset to a format usable by the imitation library.\n",
    "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly check if the expert is any good."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "\n",
    "trajectory_stats = rollout.rollout_stats(expert_trajectories)\n",
    "\n",
    "print(\n",
    "    f\"We have {trajectory_stats['n_traj']} trajectories. \"\n",
    "    f\"The average length of each trajectory is {trajectory_stats['len_mean']}. \"\n",
    "    f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have collected our expert trajectories, it's time to set up our imitation algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms import sqil\n",
    "from imitation.util.util import make_vec_env\n",
    "import numpy as np\n",
    "from stable_baselines3 import sac\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "venv = make_vec_env(\n",
    "    \"Pendulum-v1\",\n",
    "    rng=np.random.default_rng(seed=SEED),\n",
    ")\n",
    "\n",
    "sqil_trainer = sqil.SQIL(\n",
    "    venv=venv,\n",
    "    demonstrations=expert_trajectories,\n",
    "    policy=\"MlpPolicy\",\n",
    "    rl_algo_class=sac.SAC,\n",
    "    rl_kwargs=dict(seed=SEED),\n",
    ")"
   ]
  },
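  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about what the trainer does with the demonstrations: as described in the SQIL paper, SQIL ignores the environment reward and instead assigns a constant reward of 1 to expert transitions and 0 to the agent's own transitions, training on batches drawn half from each source.\n",
    "The cell below is a minimal NumPy sketch of that relabeling. It is not the `imitation` implementation: `sqil_batch` and the toy arrays are hypothetical names and made-up data used only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sqil_batch(expert_obs, agent_obs, batch_size, rng):\n",
    "    # Hypothetical helper: build a half-expert, half-agent batch with SQIL rewards.\n",
    "    half = batch_size // 2\n",
    "    expert_idx = rng.integers(len(expert_obs), size=half)\n",
    "    agent_idx = rng.integers(len(agent_obs), size=half)\n",
    "    obs = np.concatenate([expert_obs[expert_idx], agent_obs[agent_idx]])\n",
    "    # Environment rewards are discarded: expert samples get 1, agent samples get 0.\n",
    "    rewards = np.concatenate([np.ones(half), np.zeros(half)])\n",
    "    return obs, rewards\n",
    "\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "# Made-up observations; Pendulum observations have 3 dimensions.\n",
    "toy_expert_obs = toy_rng.normal(size=(100, 3))\n",
    "toy_agent_obs = toy_rng.normal(size=(100, 3))\n",
    "\n",
    "batch_obs, batch_rewards = sqil_batch(toy_expert_obs, toy_agent_obs, 8, toy_rng)\n",
    "print(batch_obs.shape, batch_rewards)"
   ]
  },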
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, the policy performs poorly. Pendulum rewards are never positive, so a poor policy shows up as a strongly negative mean return:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "reward_before_training, _ = evaluate_policy(\n",
    "    sqil_trainer.policy, venv, n_eval_episodes=100\n",
    ")\n",
    "print(f\"Reward before training: {reward_before_training}\")"
   ]
  },
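  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put these numbers in context, the sketch below computes the range of returns that are possible in Pendulum.\n",
    "It assumes the reward formula and default limits given in the Gymnasium documentation for `Pendulum-v1` (maximum speed 8, maximum torque 2, 200 steps per episode). The constants and the `pendulum_reward` helper are written out by hand for illustration, not read from the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Assumed defaults, taken from the Gymnasium documentation for Pendulum-v1.\n",
    "MAX_SPEED = 8.0\n",
    "MAX_TORQUE = 2.0\n",
    "EPISODE_STEPS = 200\n",
    "\n",
    "\n",
    "def pendulum_reward(theta, theta_dot, torque):\n",
    "    # Per-step reward: penalizes angle from upright, angular velocity, and torque.\n",
    "    return -(theta**2 + 0.1 * theta_dot**2 + 0.001 * torque**2)\n",
    "\n",
    "\n",
    "# The best case is the pendulum resting upright with no torque (reward 0).\n",
    "# The worst case is hanging down at maximum speed with maximum torque.\n",
    "worst_step = pendulum_reward(math.pi, MAX_SPEED, MAX_TORQUE)\n",
    "worst_return = EPISODE_STEPS * worst_step\n",
    "\n",
    "print(f\"Worst per-step reward: {worst_step:.2f}\")\n",
    "print(f\"Worst possible episode return: {worst_return:.1f} (the best possible is 0)\")"
   ]
  },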
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
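    "After sufficient training, the agent improves considerably (a mean return of roughly -1000 or better; returns in Pendulum are never positive), although it does not reach expert performance in this case. The cell below trains for only 1000 timesteps to keep the tutorial fast, so expect this improvement only with the longer training budget noted in the code comment."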
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sqil_trainer.train(\n",
    "    total_timesteps=1000,  # Note: set to 300_000 to obtain good results\n",
    ")\n",
    "reward_after_training, _ = evaluate_policy(\n",
    "    sqil_trainer.policy, venv, n_eval_episodes=100\n",
    ")\n",
    "print(f\"Reward after training: {reward_after_training}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
